# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import platform
import tempfile
import unittest
from io import BytesIO

import numpy as np
from op_test import get_device_place
from test_imperative_base import new_program_scope

import paddle
from paddle import base
from paddle.base import framework
from paddle.framework.io_utils import get_value, is_pir_fetch_var, set_value

IMAGE_SIZE = 784


class TestSaveLoadBinaryFormat(unittest.TestCase):
    def setUp(self):
        # enable static graph mode
        paddle.enable_static()
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def set_zero(self, prog, place, scope=None):
        if scope is None:
            scope = base.global_scope()
        for var in prog.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if is_pir_fetch_var(var):
                    continue
                ten = scope.find_var(var.name).get_tensor()
                if ten is not None:
                    ten.set(np.zeros_like(np.array(ten)), place)
                    new_t = np.array(scope.find_var(var.name).get_tensor())
                    self.assertTrue(np.sum(np.abs(new_t)) == 0)

    def replace_save_vars(self, program, dirname):
        def predicate(var):
            return var.persistable

        vars = filter(predicate, program.list_vars())
        for var in vars:
            paddle.save(
                get_value(var),
                os.path.join(dirname, var.name),
                use_binary_format=True,
            )

    def replace_load_vars(self, program, dirname):
        def predicate(var):
            return var.persistable

        var_list = list(filter(predicate, program.list_vars()))
        for var in var_list:
            var_load = paddle.load(os.path.join(dirname, var.name))
            # set var_load to scope
            set_value(var, var_load)

    def test_replace_save_load_vars(self):
        paddle.enable_static()
        with new_program_scope():
            # create network
            x = paddle.static.data(
                name="x", shape=[None, IMAGE_SIZE], dtype='float32'
            )
            z = paddle.static.nn.fc(x, 10, bias_attr=False)
            z = paddle.static.nn.fc(z, 128, bias_attr=False)
            loss = paddle.mean(z)
            place = get_device_place()
            exe = paddle.static.Executor(place)
            exe.run(paddle.static.default_startup_program())
            prog = paddle.static.default_main_program()
            base_map = {}
            for var in prog.list_vars():
                if isinstance(var, framework.Parameter) or var.persistable:
                    if is_pir_fetch_var(var):
                        continue
                    t = np.array(
                        base.global_scope().find_var(var.name).get_tensor()
                    )
                    # make sure all the parameter or optimizer var have been update
                    self.assertTrue(np.sum(np.abs(t)) != 0)
                    base_map[var.name] = t
            # test for replace_save_vars/io.load_vars
            path_vars1 = os.path.join(
                self.temp_dir.name, 'test_replace_save_load_vars_binary1/model'
            )
            self.replace_save_vars(prog, path_vars1)
            # set var to zero
            self.set_zero(prog, place)
            var_list = list(
                filter(lambda var: var.persistable, prog.list_vars())
            )
            paddle.static.io.load_vars(
                exe, path_vars1, main_program=prog, vars=var_list
            )

            for var in prog.list_vars():
                if var.persistable and not is_pir_fetch_var(var):
                    new_t = np.array(
                        base.global_scope().find_var(var.name).get_tensor()
                    )
                    base_t = base_map[var.name]

                    np.testing.assert_array_equal(new_t, base_t)
            # test for io.save_vars/replace_load_vars
            path_vars2 = os.path.join(
                self.temp_dir.name, 'test_replace_save_load_vars_binary2/model/'
            )
            paddle.static.io.save_vars(
                exe, path_vars2, main_program=prog, vars=var_list
            )
            self.set_zero(prog, place)
            self.replace_load_vars(prog, path_vars2)
            for var in prog.list_vars():
                if var.persistable and not is_pir_fetch_var(var):
                    new_t = np.array(
                        base.global_scope().find_var(var.name).get_tensor()
                    )
                    base_t = base_map[var.name]

                    np.testing.assert_array_equal(new_t, base_t)

    def test_save_load_dense_tensor(self):
        paddle.enable_static()
        OUTPUT_NUM = 32
        with new_program_scope():
            x = paddle.static.data(
                name="x", shape=[None, IMAGE_SIZE], dtype='float32'
            )
            y = paddle.static.nn.fc(
                x,
                OUTPUT_NUM,
                name='fc_vars',
            )
            prog = paddle.static.default_main_program()
            place = get_device_place()
            exe = base.Executor(place)
            prog = paddle.static.default_main_program()
            exe.run(base.default_startup_program())

            dirname = os.path.join(
                self.temp_dir.name, 'test_save_load_dense_tensor1/tensor_'
            )
            for var in prog.list_vars():
                if var.persistable and list(var.shape) == [
                    IMAGE_SIZE,
                    OUTPUT_NUM,
                ]:
                    tensor = get_value(var)
                    paddle.save(
                        tensor, dirname + 'fc_vars.w_0', use_binary_format=True
                    )
                    break

            origin = np.array(get_value(var))
            set_value(var, np.zeros_like(origin))
            is_zeros = np.array(get_value(var))

            loaded_tensor = paddle.load(dirname + 'fc_vars.w_0')
            self.assertTrue(isinstance(loaded_tensor, base.core.DenseTensor))
            self.assertTrue(
                list(loaded_tensor.shape()) == [IMAGE_SIZE, OUTPUT_NUM]
            )
            to_array = np.array(loaded_tensor)
            np.testing.assert_array_equal(origin, to_array)

        with self.assertRaises(NotImplementedError):
            path = os.path.join(self.temp_dir.name, 'test_save_load_error/temp')
            paddle.save({}, path, use_binary_format=True)
        # On the Windows platform, when parsing a string that can't be parsed as a `Program`, `desc_.ParseFromString` has a timeout risk.
        if 'Windows' != platform.system():
            with self.assertRaises(ValueError):
                path = os.path.join(
                    self.temp_dir.name, 'test_save_load_error/temp'
                )
                with open(path, "w") as f:
                    f.write('\0')
                paddle.load(path)

        with self.assertRaises(ValueError):
            temp_lod = base.core.DenseTensor()
            paddle.save(temp_lod, path, use_binary_format=True)

        with self.assertRaises(RuntimeError):
            base.core.save_dense_tensor(
                temp_lod,
                os.path.join(
                    self.temp_dir.name,
                    'test_save_load_error_not_exist_file/not_exist_file',
                ),
            )

        with self.assertRaises(RuntimeError):
            base.core.load_dense_tensor(
                temp_lod,
                os.path.join(
                    self.temp_dir.name,
                    'test_save_load_error_not_exist_file/not_exist_file',
                ),
            )

        # save to memory
        byio = BytesIO()
        paddle.save(tensor, byio, use_binary_format=True)
        byio.seek(0)
        # load from memory
        loaded_tensor_mem = paddle.load(byio)
        to_array_mem = np.array(loaded_tensor_mem)
        np.testing.assert_array_equal(np.array(tensor), to_array_mem)

        with self.assertRaises(NotImplementedError):
            paddle.framework.io._save_dense_tensor(tensor, 1)
        with self.assertRaises(NotImplementedError):
            paddle.framework.io._load_dense_tensor(1)

    def test_save_load_selected_rows(self):
        paddle.enable_static()
        place = get_device_place()
        height = 10
        rows = [0, 4, 7]
        row_numel = 12
        selected_rows = base.core.SelectedRows(rows, height)
        path = os.path.join(
            self.temp_dir.name, 'test_paddle_save_load_selected_rows/sr.pdsr'
        )

        with self.assertRaises(ValueError):
            paddle.save(selected_rows, path, use_binary_format=True)

        np_array = np.random.randn(len(rows), row_numel).astype("float32")
        tensor = selected_rows.get_tensor()
        tensor.set(np_array, place)

        paddle.save(selected_rows, path, use_binary_format=True)
        load_sr = paddle.load(path)

        self.assertTrue(isinstance(load_sr, base.core.SelectedRows))
        self.assertTrue(list(load_sr.rows()) == rows)
        self.assertTrue(load_sr.height() == height)
        np.testing.assert_array_equal(np.array(load_sr.get_tensor()), np_array)

        with self.assertRaises(RuntimeError):
            base.core.save_selected_rows(
                selected_rows,
                os.path.join(
                    self.temp_dir.name,
                    'test_paddle_save_load_selected_rows_not_exist_file/temp',
                ),
            )
        with self.assertRaises(RuntimeError):
            base.core.load_selected_rows(
                selected_rows,
                os.path.join(
                    self.temp_dir.name,
                    'test_paddle_save_load_selected_rows_not_exist_file/temp',
                ),
            )

        # save to memory
        byio = BytesIO()
        paddle.save(selected_rows, byio, use_binary_format=True)
        byio.seek(0)
        # load from memory
        selected_rows_mem = paddle.load(byio)
        to_array_mem = np.array(selected_rows_mem)
        self.assertTrue(isinstance(selected_rows_mem, base.core.SelectedRows))
        self.assertTrue(list(selected_rows_mem.rows()) == rows)
        self.assertTrue(selected_rows_mem.height() == height)
        np.testing.assert_array_equal(
            np.array(selected_rows_mem.get_tensor()), np_array
        )

        with self.assertRaises(NotImplementedError):
            paddle.framework.io._save_selected_rows(selected_rows, 1)
        with self.assertRaises(NotImplementedError):
            paddle.framework.io._load_selected_rows(1)


if __name__ == '__main__':
    unittest.main()
